In [1]:
%load_ext autoreload
%autoreload 2
import os
os.chdir("../..")
from pathlib import Path
import logging
import pandas as pd
from prophet import Prophet
from prophet.diagnostics import cross_validation, performance_metrics
from prophet.plot import plot_plotly, plot_components_plotly
from src.config.config import *
from src.features.build_features import run_pipeline
from src.models.train_model import mass_forecaster
import pickle
logging.getLogger("prophet").setLevel(logging.ERROR)
logging.getLogger("cmdstanpy").disabled = True
In [2]:
# Read the project configuration, then build the train/test feature frames
# (both indexed by Store and ds).
conf = get_config()
df_train, df_test = run_pipeline(
    conf.DATA_PATH
)
Loading config file: "./conf/config.yaml" Feature datasets have been saved!
In [3]:
# First rows of the training features; `y` is the target column.
df_train.head(n=5)
Out[3]:
| cat__Promo_1.0 | cat__SchoolHoliday_1.0 | y | ||
|---|---|---|---|---|
| Store | ds | |||
| 1 | 2013-01-02 | 0.0 | 1.0 | 5530 |
| 2013-01-03 | 0.0 | 1.0 | 4327 | |
| 2013-01-04 | 0.0 | 1.0 | 4486 | |
| 2013-01-05 | 0.0 | 1.0 | 4997 | |
| 2013-01-07 | 1.0 | 1.0 | 7176 |
In [4]:
# First rows of the test features (same regressors as training, no `y`).
df_test.head(n=5)
Out[4]:
| cat__Promo_1.0 | cat__SchoolHoliday_1.0 | ||
|---|---|---|---|
| Store | ds | ||
| 1 | 2015-08-01 | 0.0 | 1.0 |
| 2015-08-02 | 0.0 | 1.0 | |
| 2015-08-03 | 1.0 | 1.0 | |
| 2015-08-04 | 1.0 | 1.0 | |
| 2015-08-05 | 1.0 | 1.0 |
Run cross-validation with backfitting and save the best model per store
In [5]:
# Hyper-parameter grid searched per store during cross-validation.
# A bare last expression uses IPython's rich/pretty display instead of print().
conf.PARAM_GRID
{'changepoint_prior_scale': [0.001, 0.01, 0.1, 0.5], 'seasonality_mode': ['additive', 'multiplicative'], 'seasonality_prior_scale': [0.01, 0.1, 1.0, 10.0]}
In [6]:
# Cross-validate and backfit
mass_forecaster(conf)
Starting forecasting procedure for Store:1
Best params are {'changepoint_prior_scale': 0.1, 'seasonality_mode': 'additive', 'seasonality_prior_scale': 0.01}
Best rmse is : 623.84
Starting forecasting procedure for Store:3
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 1023.23
Starting forecasting procedure for Store:7
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 1595.29
Starting forecasting procedure for Store:8
Best params are {'changepoint_prior_scale': 0.1, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 813.30
Starting forecasting procedure for Store:9
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 10.0}
Best rmse is : 1068.19
Starting forecasting procedure for Store:10
Best params are {'changepoint_prior_scale': 0.5, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.01}
Best rmse is : 695.85
Starting forecasting procedure for Store:11
Best params are {'changepoint_prior_scale': 0.1, 'seasonality_mode': 'additive', 'seasonality_prior_scale': 0.01}
Best rmse is : 1278.95
Starting forecasting procedure for Store:12
Best params are {'changepoint_prior_scale': 0.5, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 1.0}
Best rmse is : 1285.59
Starting forecasting procedure for Store:13
Best params are {'changepoint_prior_scale': 0.001, 'seasonality_mode': 'additive', 'seasonality_prior_scale': 0.01}
Best rmse is : 1099.03
Starting forecasting procedure for Store:14
Best params are {'changepoint_prior_scale': 0.01, 'seasonality_mode': 'multiplicative', 'seasonality_prior_scale': 0.1}
Best rmse is : 698.78
Load a single saved model and predict
In [7]:
# `pickle` is imported in the setup cell; no need to re-import it here.
# Saved models are named "<store>.pkl"; valid stores are those forecast above
# (1, 3, 7, ...).
store = 1
model_file = conf.MODEL_PATH / "saved_models" / f"{store}.pkl"
# NOTE: pickle.load can execute arbitrary code; only load model files that this
# project produced itself.
with model_file.open("rb") as handle:
    model_1 = pickle.load(handle)
Visualizations
In [8]:
# Forecast over the selected store's full timeline (train followed by test).
# pd.concat stacks row-wise by default, so no explicit axis is needed.
plot_plotly(
    model_1,
    model_1.predict(
        pd.concat(
            [frame.loc[store].reset_index() for frame in (df_train, df_test)]
        )
    ),
)
In [9]:
# Component decomposition for the selected store's test period only.
# df_test is indexed by (Store, ds); without `.loc[store]`, rows from every
# store would be fed to this single-store model and mixed into the plot.
plot_components_plotly(
    model_1,
    model_1.predict(df_test.loc[store].reset_index()),
)
In [10]:
# Per-store, per-parameter-combination RMSE from the backfitting grid search.
pd.read_csv(Path("models", "results", "tuning_results.csv"), index_col=[0])
Out[10]:
| store | changepoint_prior_scale | seasonality_mode | seasonality_prior_scale | rmse | |
|---|---|---|---|---|---|
| 0 | 1 | 0.001 | additive | 0.01 | 758.632796 |
| 1 | 1 | 0.001 | additive | 0.10 | 799.090963 |
| 2 | 1 | 0.001 | additive | 1.00 | 739.941039 |
| 3 | 1 | 0.001 | additive | 10.00 | 786.382349 |
| 4 | 1 | 0.001 | multiplicative | 0.01 | 782.214540 |
| ... | ... | ... | ... | ... | ... |
| 27 | 14 | 0.500 | additive | 10.00 | 733.883589 |
| 28 | 14 | 0.500 | multiplicative | 0.01 | 773.275699 |
| 29 | 14 | 0.500 | multiplicative | 0.10 | 736.575138 |
| 30 | 14 | 0.500 | multiplicative | 1.00 | 737.317667 |
| 31 | 14 | 0.500 | multiplicative | 10.00 | 737.817153 |
320 rows × 5 columns
In [11]:
# Saved forecasts for every store (train and test dates).
pd.read_csv(Path("models", "results", "forecasts.csv"), index_col=[0])
Out[11]:
| store | ds | trend | yhat_lower | yhat_upper | trend_lower | trend_upper | additive_terms | additive_terms_lower | additive_terms_upper | weekly | weekly_lower | weekly_upper | yearly | yearly_lower | yearly_upper | multiplicative_terms | multiplicative_terms_lower | multiplicative_terms_upper | yhat | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 1 | 2013-01-02 | 5284.447944 | 4437.646265 | 6615.198120 | 5284.447944 | 5284.447944 | 225.284287 | 225.284287 | 225.284287 | -168.384685 | -168.384685 | -168.384685 | 393.668971 | 393.668971 | 393.668971 | 0.000000 | 0.000000 | 0.000000 | 5509.732231 |
| 1 | 1 | 2013-01-03 | 5282.426092 | 4205.893135 | 6439.996854 | 5282.426092 | 5282.426092 | 47.561273 | 47.561273 | 47.561273 | -243.860775 | -243.860775 | -243.860775 | 291.422048 | 291.422048 | 291.422048 | 0.000000 | 0.000000 | 0.000000 | 5329.987365 |
| 2 | 1 | 2013-01-04 | 5280.404241 | 4391.347974 | 6492.928135 | 5280.404241 | 5280.404241 | 168.446565 | 168.446565 | 168.446565 | -25.282493 | -25.282493 | -25.282493 | 193.729058 | 193.729058 | 193.729058 | 0.000000 | 0.000000 | 0.000000 | 5448.850806 |
| 3 | 1 | 2013-01-05 | 5278.382389 | 4433.331538 | 6531.860868 | 5278.382389 | 5278.382389 | 253.192242 | 253.192242 | 253.192242 | 151.482467 | 151.482467 | 151.482467 | 101.709775 | 101.709775 | 101.709775 | 0.000000 | 0.000000 | 0.000000 | 5531.574631 |
| 4 | 1 | 2013-01-07 | 5274.338685 | 4514.679064 | 6698.863744 | 5274.338685 | 5274.338685 | 290.639197 | 290.639197 | 290.639197 | 352.115898 | 352.115898 | 352.115898 | -61.476701 | -61.476701 | -61.476701 | 0.000000 | 0.000000 | 0.000000 | 5564.977882 |
| ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
| 43 | 14 | 2015-09-13 | 5691.851317 | 3896.936149 | 6524.159405 | 5691.850807 | 5691.851818 | 0.000000 | 0.000000 | 0.000000 | 0.019224 | 0.019224 | 0.019224 | -0.105064 | -0.105064 | -0.105064 | -0.085840 | -0.085840 | -0.085840 | 5203.263804 |
| 44 | 14 | 2015-09-14 | 5692.133133 | 5163.000069 | 7668.927595 | 5692.132595 | 5692.133666 | 0.000000 | 0.000000 | 0.000000 | 0.234822 | 0.234822 | 0.234822 | -0.098604 | -0.098604 | -0.098604 | 0.136218 | 0.136218 | 0.136218 | 6467.503347 |
| 45 | 14 | 2015-09-15 | 5692.414949 | 4300.793334 | 6785.913349 | 5692.414394 | 5692.415531 | 0.000000 | 0.000000 | 0.000000 | 0.068748 | 0.068748 | 0.068748 | -0.091335 | -0.091335 | -0.091335 | -0.022587 | -0.022587 | -0.022587 | 5563.837657 |
| 46 | 14 | 2015-09-16 | 5692.696765 | 3926.175141 | 6464.533322 | 5692.696186 | 5692.697366 | 0.000000 | 0.000000 | 0.000000 | -0.011180 | -0.011180 | -0.011180 | -0.083386 | -0.083386 | -0.083386 | -0.094565 | -0.094565 | -0.094565 | 5154.364468 |
| 47 | 14 | 2015-09-17 | 5692.978581 | 4075.649832 | 6669.954978 | 5692.977983 | 5692.979201 | 0.000000 | 0.000000 | 0.000000 | 0.004806 | 0.004806 | 0.004806 | -0.074900 | -0.074900 | -0.074900 | -0.070093 | -0.070093 | -0.070093 | 5293.938169 |
8141 rows × 20 columns